# -*- coding: utf-8 -*-
# @Organization  : Alibaba XR-Lab
# @Author        : Lingteng Qiu
# @Email         : 220019047@link.cuhk.edu.cn
# @Time          : 2025-03-03 10:29:00
# @Function      : easy to use FaceSimilarity metric

import os
import pdb
import shutil
import sys

sys.path.append("./")
from collections import defaultdict

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from prettytable import PrettyTable
from torch.utils.data import Dataset
from torchmetrics.image import StructuralSimilarityIndexMeasure
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchvision import transforms
from tqdm import tqdm
from tqlt import utils as tu

from openlrm.models.arcface_utils import ResNetArcFace
from openlrm.utils.face_detector import FaceDetector

device = "cuda"
model_path = "./pretrained_models/gagatracker/vgghead/vgg_heads_l.trcd"
face_detector = FaceDetector(model_path=model_path, device=device)

id_face_net = ResNetArcFace()
id_face_net.cuda()
id_face_net.eval()


def get_image_paths_current_dir(folder_path):
    image_extensions = {
        ".jpg",
        ".jpeg",
        ".png",
        ".gif",
        ".bmp",
        ".tiff",
        ".webp",
        ".jfif",
    }

    return sorted(
        [
            os.path.join(folder_path, f)
            for f in os.listdir(folder_path)
            if os.path.splitext(f)[1].lower() in image_extensions
        ]
    )


def write_json(path, x):
    """write a json file.

    Args:
        path (str): path to write json file.
        x (dict): dict to write.
    """
    import json

    with open(path, "w") as f:
        json.dump(x, f, indent=2)


def crop_face_image(image_path):
    rgb = np.array(Image.open(image_path))
    rgb = torch.from_numpy(rgb).permute(2, 0, 1)
    bbox = face_detector(rgb)
    head_rgb = rgb[:, int(bbox[1]) : int(bbox[3]), int(bbox[0]) : int(bbox[2])]
    head_rgb = head_rgb.permute(1, 2, 0)
    head_rgb = head_rgb.cpu().numpy()
    return head_rgb


def gray_resize_for_identity(out, size=128):
    out_gray = (
        0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :]
    )
    out_gray = out_gray.unsqueeze(1)
    out_gray = F.interpolate(
        out_gray, (size, size), mode="bilinear", align_corners=False
    )
    return out_gray


@torch.no_grad()
def eval(input_folder, target_folder, device="cuda"):

    gt_imgs = get_image_paths_current_dir(target_folder)
    result_imgs = get_image_paths_current_dir(input_folder)

    gt_imgs = gt_imgs[::4]
    result_imgs = result_imgs[::4]

    if "visualization" in result_imgs[-1]:
        result_imgs = result_imgs[:-1]

    if len(gt_imgs) != len(result_imgs):
        return -1

    to_tensor = transforms.ToTensor()

    face_id_loss_list = []

    for input_img, gt_img in zip(result_imgs, gt_imgs):

        try:
            input_img = crop_face_image(input_img)
            input_head_tensor = gray_resize_for_identity(
                to_tensor(input_img).unsqueeze(0).to(device)
            )
            input_head_feature = id_face_net(input_head_tensor).detach()

            head_img = crop_face_image(gt_img)
            head_img = to_tensor(head_img).unsqueeze(0).to(device)
            src_head_tensor = gray_resize_for_identity(head_img)
            src_head_feature = id_face_net(src_head_tensor).detach()

            face_id_loss = F.l1_loss(input_head_feature, src_head_feature)
            face_id_loss_list.append(face_id_loss.item())
        except:
            continue

    if len(face_id_loss_list) > 0:

        return np.mean(face_id_loss_list)  # return max similarity view.
    else:
        return -1


def get_parse():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-f1", "--folder1", type=str, required=True)
    parser.add_argument("-f2", "--folder2", type=str, required=True)
    parser.add_argument("--pad", action="store_true")
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()
    return args


if __name__ == "__main__":

    opt = get_parse()

    input_folder = opt.folder1
    target_folder = opt.folder2

    save_folder = os.path.join("./exps/metricsanigs", "psnr_results")
    os.makedirs(save_folder, exist_ok=True)

    input_folders = tu.next_folders(input_folder)

    results_dict = defaultdict(dict)

    results_dict = defaultdict(dict)
    face_similarity_list = []

    for input_folder in input_folders:

        item_basename = tu.basename(input_folder)

        mask_item_folder = None
        input_item_folder = os.path.join(input_folder, "rgb")
        target_item_folder = os.path.join(target_folder, item_basename)

        if os.path.exists(input_item_folder) and os.path.exists(target_item_folder):

            fs_ = eval(input_item_folder, target_item_folder)

            if fs_ == -1:
                continue

            face_similarity_list.append(fs_)

            results_dict[item_basename]["face_similarity"] = fs_
            if opt.debug:
                break
            print(results_dict)

    results_dict["all_mean"]["face_similarity"] = np.mean(face_similarity_list)

    write_json(os.path.join(save_folder, "face_similarity.json"), results_dict)
